{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Intro to combining plots\n",
    "========================\n",
    "\n",
    "In this notebook you will learn how to combine plots in a simple way."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Types of multiple plots\n",
    "\n",
    "There are three ways of combining your plots in the visualization framework, each with its associated class:\n",
    "\n",
    "- `\"multiple\"`: it's the most basic one. It just takes the drawinfs from all plots and displays them in the same plot.\n",
    "- `\"subplots\"`: Creates a grid of subplots, where each item of the grid contains a plot.\n",
    "- `\"multiple_x\"` and `\"multiple_y\"` (multiple_A): Creates a plot where a separate A axis is created for each plot, while the rest of axes are shared.\n",
    "- `\"animation\"`: Creates an animation where each child plot is represented in a frame.\n",
    "\n",
    "They can all be acheived with the `merge_plots` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sisl.viz import merge_plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a simple tight-binding model for *hBN* to experiment with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sisl\n",
    "import numpy as np\n",
    "\n",
    "r = np.linspace(0, 3.5, 50)\n",
    "f = np.exp(-r)\n",
    "\n",
    "orb = sisl.AtomicOrbital(\"2pzZ\", (r, f))\n",
    "geom = sisl.geom.graphene(\n",
    "    orthogonal=False, atoms=[sisl.Atom(5, orb), sisl.Atom(7, orb)]\n",
    ")\n",
    "geom = geom.move([0, 0, 5])\n",
    "H = sisl.Hamiltonian(geom)\n",
    "H.construct(\n",
    "    [(0.1, 1.44), (0, -2.7)],\n",
    ")\n",
    "H[0, 0] = -0.7\n",
    "H[1, 1] = 0.7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Individual plots\n",
    "\n",
    "As an example, from the hamiltonian that we constructed, let's build a bands plot and a pdos plot:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "band_structure = sisl.BandStructure(\n",
    "    H,\n",
    "    [[0, 0, 0], [0, 0.5, 0], [1 / 3, 2 / 3, 0], [0, 0, 0]],\n",
    "    400,\n",
    "    [r\"Gamma\", r\"M\", r\"K\", r\"Gamma\"],\n",
    ")\n",
    "bands_plot = band_structure.plot()\n",
    "pdos_plot = H.plot.pdos(\n",
    "    data_Erange=[-10, 10], Erange=[-10, 10], kgrid=[121, 121, 1], nE=1000\n",
    ").split_DOS(name=\"$species\")\n",
    "\n",
    "plots = [bands_plot, pdos_plot]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's check the plots individually:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bands_plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pdos_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now, we will merge them.\n",
    "\n",
    "Merging into a single plot\n",
    "----"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merge_plots(*plots)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, `merge_plots` uses the `\"multiple\"` method to merge the plots. In this case, it is not very nice, because the two axes are different for bands and pdos.\n",
    "\n",
    "However, they have one axis in common! The energy axis. We can use this fact to combine them in a way that they share the energy axis but have each a separate one for the other axis. \n",
    "\n",
    "Independent axes\n",
    "-------------\n",
    "\n",
    "First, we need to make sure that both energy axis are on the X or Y axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pdos_plot = pdos_plot.update_inputs(E_axis=\"y\")\n",
    "bands_plot = bands_plot.update_inputs(E_axis=\"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And then we can use `multiple_x` so that each plot has a separate X axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merge_plots(*plots, composite_method=\"multiple_x\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Much better, right? Now we can easily see that B contributes more to the bottom band, while N contributes more to the top band."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Subplots\n",
    "--------\n",
    "\n",
    "Let's try now to use the `\"subplots\"` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merge_plots(*plots, composite_method=\"subplots\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default it puts one plot on each row, but we can manage that with the arguments `rows` (number of rows), `cols` (number of columns), and `arrange` (if rows or cols are missing, way to determine the missing value, can be \"rows\", \"cols\" or \"square\").\n",
    "\n",
    "Let's put the two plots in separate columns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "nbsphinx-thumbnail"
    ]
   },
   "outputs": [],
   "source": [
    "merge_plots(*plots, composite_method=\"subplots\", cols=2).show(\"png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "    \n",
    "Note\n",
    "    \n",
    "There is also the `sisl.viz.subplots` function, which might be more convenient to use in a notebook because its help message is more helpful.\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Merging merged plots\n",
    "-------------------\n",
    "\n",
    "We can recursively merge plots. Unfortunately however, for the moment only the top level merge method is taken into account. The other levels are simply taken as `\"multiple\"`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_plot = merge_plots(*plots, composite_method=\"multiple_x\")\n",
    "\n",
    "merge_plots(\n",
    "    merged_plot, bands_plot, composite_method=\"subplots\", cols=2, backend=\"plotly\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the future, separate axes within subplots might be supported."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Animations\n",
    "----------\n",
    "\n",
    "Animations can be very cool but they are sometimes hard to build. `merge_plots` makes it as easy as possible for you, you just need to use the `\"animation\"` method.\n",
    "\n",
    "Let's create an animation to see the convergence of graphene's PDOS with the number of k points. We first create the plots:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the number of k points that we are going to try.\n",
    "# Do 1 by 1 from 1 to 12 and then in steps of 5 from 15 to 90.\n",
    "ks = [*np.arange(1, 12), *np.arange(15, 90, 5)]\n",
    "\n",
    "# Generate all plots.\n",
    "# We use the scatter trace instead of a line because it looks better in animations :)\n",
    "pdos_plots = [\n",
    "    H.plot.pdos(\n",
    "        data_Erange=[-10, 10],\n",
    "        Erange=[-10, 10],\n",
    "        kgrid=[k, k, 1],\n",
    "        nE=1000,\n",
    "        line_mode=\"scatter\",\n",
    "        line_scale=2,\n",
    "    ).split_DOS(name=\"$species\")\n",
    "    for k in ks\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now all the heavy computation is done! We can merge the plots into an animation, using the ks as frame names. Other arguments that you can pass to an animation are `frame_duration` (in ms), `transition` (in ms) and `redraw` (Wether to redraw the whole plot for each frame).\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "    \n",
    "Note\n",
    "    \n",
    "We suggest that you go to the last frame and click the house icon to set the y axis range. Then press play and see the PDOS converge!\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "merge_plots(*pdos_plots, composite_method=\"animation\", frame_names=ks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "    \n",
    "Note\n",
    "    \n",
    "There is also the `sisl.viz.animation` function, which might be more convenient to use in a notebook because its help message is more helpful.\n",
    "    \n",
    "</div>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
